# coding=utf-8
import json
import os
import socket

import paramiko


class SSHClient:
    """Wrapper around ``paramiko.SSHClient`` with lazily opened SFTP transfers.

    Can be used as a context manager: leaving the ``with`` block calls
    :meth:`close`.
    """

    def __init__(self, cli: "paramiko.SSHClient"):
        """
        :param cli: the paramiko client to wrap (``from_param`` passes one that is already connected)
        """
        self.cli = cli
        # SFTP session; opened on the first upload/download by _sftp().
        self.sftp_cli = None

    @classmethod
    def from_pool(cls, pool, host_id):
        """Obtain a client from a connection pool.

        TODO: the SSH transport connection pool is not implemented yet, so this
        always returns None.
        """
        return None

    @classmethod
    def from_param(cls, host, port, user, password):
        """Open a password-authenticated SSH connection and wrap it.

        :param host: remote host name or address
        :param port: SSH port
        :param user: login user name
        :param password: login password
        :return: a new instance wrapping the connected client
        :raises TimeoutError: the connection attempt timed out
        :raises ConnectionError: any other connection or authentication failure
        """
        cli = paramiko.SSHClient()
        # NOTE(review): the base MissingHostKeyPolicy appears to accept unknown host
        # keys without verifying them (its hook does nothing), which allows MITM.
        # Confirm this is intended; RejectPolicy plus load_system_host_keys() would
        # verify against known_hosts. Instantiated explicitly rather than relying on
        # paramiko to instantiate a class argument.
        cli.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy())
        try:
            cli.connect(host, port=port, username=user, password=password, timeout=5, auth_timeout=10)
        except (TimeoutError, socket.timeout) as e:
            # socket.timeout is only an alias of TimeoutError from Python 3.10 on.
            cli.close()
            raise TimeoutError(f"connect to {host} timeout") from e
        except Exception as e:
            # Release the transport of the failed attempt before re-raising.
            cli.close()
            raise ConnectionError(f"connect to {host} exception: {e}") from e

        return cls(cli)

    def exec_command(self, command, **kwargs):
        """Run *command* on the remote host; delegates to ``paramiko.SSHClient.exec_command``.

        :return: the ``(stdin, stdout, stderr)`` triple returned by paramiko
        """
        return self.cli.exec_command(command, **kwargs)

    def _sftp(self):
        """Return the SFTP session, opening it over the existing transport on first use."""
        if self.sftp_cli is None:
            self.sftp_cli = paramiko.SFTPClient.from_transport(self.cli.get_transport())
        return self.sftp_cli

    def upload(self, local, remote):
        """Copy a local file to *remote*.

        :param local: a file path (``str`` or ``os.PathLike``), or an opened
            file / file-like object to read from
        :param remote: str, destination path on the remote host
        :return: None
        """
        sftp = self._sftp()
        if isinstance(local, (str, os.PathLike)):
            sftp.put(local, remote)
        else:
            sftp.putfo(local, remote)

    def download(self, remote, local):
        """Copy the remote file *remote* to *local*.

        :param remote: str, source path on the remote host
        :param local: a file path (``str`` or ``os.PathLike``), or a writable
            file-like object
        :return: None
        """
        sftp = self._sftp()
        if isinstance(local, (str, os.PathLike)):
            sftp.get(remote, local)
        else:
            sftp.getfo(remote, local)

    def close(self):
        """Close the SFTP session (if one was opened) and the SSH connection."""
        if self.sftp_cli is not None:
            self.sftp_cli.close()
            self.sftp_cli = None
        self.cli.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always close; returning None lets any exception propagate.
        self.close()


class CommandMixin:
    """Common host-inspection commands run through ``self.cli``.

    ``self.cli`` only needs an ``exec_command(command)`` method returning a
    ``(stdin, stdout, stderr)`` triple. The parsing below assumes
    ``stdout.read()`` returns bytes and ``stdout.readlines()`` returns text
    lines (NOTE(review): this matches paramiko's channel files - confirm if
    another client is plugged in).
    """

    def __init__(self, cli: "SSHClient"):
        # Quoted annotation: SSHClient need not be resolvable at definition time.
        self.cli = cli

    def _stdout(self, command):
        """Run *command* via ``self.cli`` and return only its stdout stream."""
        _, stdout, _ = self.cli.exec_command(command)
        return stdout

    def _read_text(self, command) -> str:
        """Run *command* and return its whole stdout, stripped and decoded as UTF-8."""
        data = self._stdout(command).read().strip()
        if isinstance(data, bytes):
            data = data.decode("utf-8")
        return data

    def hostname(self) -> str:
        """Return the output of ``hostname``."""
        return self._read_text("hostname")

    def os_release(self) -> str:
        """Return a human-readable OS name parsed from ``/etc/os-release``.

        Prefers ``NAME + " " + VERSION``, then ``PRETTY_NAME``. If neither is
        present, returns all parsed key/value pairs as a JSON string.
        Surrounding single or double quotes are removed from values.
        """
        data = {}
        for raw in self._stdout("cat /etc/os-release").readlines():
            line = raw.strip()
            # Skip blank lines, '#' comments and anything that is not KEY=VALUE.
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", maxsplit=1)
            data[key] = value.strip().strip("\"'")

        if "NAME" in data and "VERSION" in data:
            return data["NAME"] + " " + data["VERSION"]
        if "PRETTY_NAME" in data:
            return data["PRETTY_NAME"]
        return json.dumps(data)

    def serial_number(self) -> str:
        """Return the system serial number reported by ``dmidecode``."""
        return self._read_text("dmidecode -s system-serial-number")

    def system_uuid(self) -> str:
        """Return the system UUID reported by ``dmidecode``."""
        return self._read_text("dmidecode -s system-uuid")

    def cpu(self) -> dict:
        """Return ``{"model": str, "cores": int}`` from ``/proc/cpuinfo``.

        ``cores`` is the number of "model name" lines (one per processor entry,
        i.e. presumably logical CPUs). When there are no such lines, ``model``
        is ``""`` and ``cores`` is 0, instead of raising IndexError.
        """
        lines = self._stdout('grep "model name"  /proc/cpuinfo').readlines()
        # Split on the first ':' only, so model names containing ':' stay whole.
        model = lines[0].split(":", maxsplit=1)[-1].strip() if lines else ""
        return {"model": model, "cores": len(lines)}

    def uptime(self) -> str:
        """Return the output of ``uptime -s`` (the boot timestamp, not a duration)."""
        return self._read_text("uptime -s")

    def memory(self) -> dict:
        """Return the "Mem:" row of ``free -h -w`` as a dict of strings.

        Keys: total, free, used, shared, buffers, cache, available. Values are
        the human-readable sizes printed by ``free -h``.

        :raises ValueError: the output does not have the expected 8 fields
        """
        output = self._stdout('free -h -w |grep "Mem:"').read().decode()
        fields = output.split()
        # Expected: the "Mem:" label followed by the 7 columns of `free -w`.
        if len(fields) != 8:
            raise ValueError(f"unexpected output from free: {output!r}")
        _, total, used, free, shared, buffers, cache, available = fields

        return {"total": total, "free": free, "used": used,
                "shared": shared, "buffers": buffers, "cache": cache, "available": available}

    def filesystem(self) -> "list[dict]":
        """Return one dict per mounted filesystem listed by ``df -h -T``.

        overlay, tmpfs and devtmpfs filesystems are excluded. Keys: filesystem,
        type, size, used, avail, use_percent, mounted.
        """
        keys = ("filesystem", "type", "size", "used", "avail", "use_percent", "mounted")
        lines = self._stdout("df -h -x overlay -x tmpfs -x devtmpfs -T").readlines()

        data = []
        for line in lines[1:]:  # lines[0] is the column header
            # maxsplit=6 keeps a mount point containing spaces whole in the last field.
            item = line.strip().split(maxsplit=6)
            if len(item) < len(keys):
                continue
            data.append(dict(zip(keys, item)))

        return data

    def nics(self) -> "list[dict]":
        """Return one dict of ``key: value`` pairs per device in ``lshw -C network``.

        The output is split into sections on the "*-network" marker; text
        before the first marker yields no pairs and is dropped.
        """
        text = self._stdout("lshw -C network").read().decode("utf-8")

        data = []
        for block in text.split("*-network"):
            item = {}
            for line in block.splitlines():
                key, sep, value = line.partition(":")
                key = key.strip()
                # Skip lines without ':' and lines whose part before ':' is blank
                # (e.g. the ':0' left over from a '*-network:0' header).
                if sep and key:
                    item[key] = value.strip()
            if item:
                data.append(item)

        return data


class RemoteHost(SSHClient, CommandMixin):
    """An SSH connection that also offers the CommandMixin host-inspection commands.

    SSHClient precedes CommandMixin in the MRO, so SSHClient.__init__ runs and
    ``self.cli`` is the wrapped paramiko client; the mixin methods therefore
    call ``exec_command`` on that paramiko client directly.
    """


if __name__ == '__main__':
    # Demo: print a basic inventory of one remote host.
    # Credentials come from the command line, the SSH_PASSWORD environment
    # variable or an interactive prompt instead of being hard-coded in source.
    import argparse
    import getpass
    import os

    parser = argparse.ArgumentParser(description="Print basic inventory of a remote host over SSH.")
    parser.add_argument("host")
    parser.add_argument("--port", type=int, default=22)
    parser.add_argument("--user", default="root")
    args = parser.parse_args()
    password = os.environ.get("SSH_PASSWORD") or getpass.getpass(f"{args.user}@{args.host} password: ")

    remote_host = RemoteHost.from_param(args.host, args.port, args.user, password)
    try:
        print(remote_host.hostname())
        print(remote_host.os_release())
        print(remote_host.serial_number())
        print(remote_host.filesystem())
        print(remote_host.cpu())
        print(remote_host.memory())
        print(json.dumps(remote_host.nics()))
    finally:
        # Close the connection even if one of the commands fails.
        remote_host.close()
